/*
 * Decompiled with CFR 0.152.
 */
package weka.experiment;

import java.util.Enumeration;
import java.util.HashSet;
import java.util.Hashtable;
import java.util.Vector;
import weka.core.AdditionalMeasureProducer;
import weka.core.FastVector;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.experiment.CSVResultListener;
import weka.experiment.CrossValidationResultProducer;
import weka.experiment.DatabaseUtils;
import weka.experiment.ResultListener;
import weka.experiment.ResultProducer;
import weka.experiment.Stats;

/**
 * Wraps another {@link ResultProducer} and averages, per run, the results it
 * produces across the values of one key field (by default the
 * cross-validation fold field). Numeric result fields are averaged (optionally
 * with standard deviations); for non-numeric result fields the value from the
 * first matching result is used. Averaged results are forwarded to the
 * configured {@link ResultListener} under a key with the averaged-over field
 * removed.
 */
public class AveragingResultProducer
implements ResultListener,
ResultProducer,
OptionHandler,
AdditionalMeasureProducer,
RevisionHandler {
    static final long serialVersionUID = 2551284958501991352L;
    /** Dataset handed to the wrapped producer before each run. */
    protected Instances m_Instances;
    /** Receives the averaged results. */
    protected ResultListener m_ResultListener = new CSVResultListener();
    /** The producer whose results are averaged. */
    protected ResultProducer m_ResultProducer = new CrossValidationResultProducer();
    /** Additional measures forwarded to the wrapped producer (may be null). */
    protected String[] m_AdditionalMeasures = null;
    /** Number of results that must match each averaging template. */
    protected int m_ExpectedResultsPerAverage = 10;
    /** Whether a standard-deviation column is emitted for each numeric result. */
    protected boolean m_CalculateStdDevs;
    /** Name of the result column holding the number of results averaged. */
    protected String m_CountFieldName = "Num_" + CrossValidationResultProducer.FOLD_FIELD_NAME;
    /** Name of the key field whose values are averaged over. */
    protected String m_KeyFieldName = CrossValidationResultProducer.FOLD_FIELD_NAME;
    /** Index of {@code m_KeyFieldName} in the wrapped producer's keys, or -1. */
    protected int m_KeyIndex = -1;
    /** Keys collected from the wrapped producer during the current run. */
    protected FastVector m_Keys = new FastVector();
    /** Results collected from the wrapped producer, parallel to {@code m_Keys}. */
    protected FastVector m_Results = new FastVector();

    /**
     * Returns a description of this producer for display in the GUI.
     *
     * @return the global info string
     */
    public String globalInfo() {
        return "Takes the results from a ResultProducer and submits the average to the result listener. Normally used with a CrossValidationResultProducer to perform n x m fold cross validation. For non-numeric result fields, the first value is used.";
    }

    /**
     * Locates {@code m_KeyFieldName} among the wrapped producer's key names and
     * caches the position in {@code m_KeyIndex}.
     *
     * @return the key index, or -1 if there is no producer, the field is not
     *         among its keys, or the producer failed to report its key names
     */
    protected int findKeyIndex() {
        m_KeyIndex = -1;
        if (m_ResultProducer == null) {
            return m_KeyIndex;
        }
        try {
            String[] keyNames = m_ResultProducer.getKeyNames();
            for (int i = 0; i < keyNames.length; i++) {
                if (keyNames[i].equals(m_KeyFieldName)) {
                    m_KeyIndex = i;
                    break;
                }
            }
        } catch (Exception ignored) {
            // Deliberate best effort: a failure here just leaves the index at -1;
            // preProcess() and the key accessors report the missing key field.
        }
        return m_KeyIndex;
    }

    /**
     * Imposes no column constraints.
     *
     * @param rp the producer asking (unused)
     * @return always null
     * @throws Exception never thrown here
     */
    @Override
    public String[] determineColumnConstraints(ResultProducer rp) throws Exception {
        return null;
    }

    /**
     * Asks the wrapped producer for the keys of the given run and derives the
     * averaging template: the first key with the averaged-over field nulled out.
     *
     * @param run the run number
     * @return the averaging template
     * @throws Exception if no instances are set, the producer fails, or the
     *         collected keys are inconsistent
     */
    protected Object[] determineTemplate(int run) throws Exception {
        if (m_Instances == null) {
            throw new Exception("No Instances set");
        }
        m_ResultProducer.setInstances(m_Instances);
        m_Keys.removeAllElements();
        m_Results.removeAllElements();
        m_ResultProducer.doRunKeys(run);
        return templateFromCollectedKeys();
    }

    /**
     * Reports to the listener the (reduced) key that a full run would produce.
     *
     * @param run the run number
     * @throws Exception if the template cannot be determined or the listener fails
     */
    @Override
    public void doRunKeys(int run) throws Exception {
        Object[] newKey = removeKeyField(determineTemplate(run));
        m_ResultListener.acceptResult(this, newKey, null);
    }

    /**
     * Performs a run: if the listener still needs the averaged result for this
     * run, runs the wrapped producer, averages its results and forwards them.
     *
     * @param run the run number
     * @throws Exception if any stage of collection or averaging fails
     */
    @Override
    public void doRun(int run) throws Exception {
        Object[] newKey = removeKeyField(determineTemplate(run));
        if (!m_ResultListener.isResultRequired(this, newKey)) {
            return;
        }
        m_Keys.removeAllElements();
        m_Results.removeAllElements();
        m_ResultProducer.doRun(run);
        doAverageResult(templateFromCollectedKeys());
    }

    /**
     * Tests whether a key matches a template; null template entries act as
     * wildcards.
     *
     * @param template the template (nulls match anything)
     * @param test the key to test
     * @return true if the lengths agree and every non-null template entry equals
     *         the corresponding key entry
     */
    protected boolean matchesTemplate(Object[] template, Object[] test) {
        if (template.length != test.length) {
            return false;
        }
        for (int i = 0; i < test.length; i++) {
            if (template[i] != null && !template[i].equals(test[i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Averages all collected results whose key matches the template and sends
     * the averaged result to the listener (if it still requires it).
     *
     * <p>A numeric field becomes null in the output if any matching result has
     * a null value for it. Non-numeric fields are taken from the first matching
     * result.
     *
     * @param template the averaging template
     * @throws Exception if the number of matching results is not the expected
     *         count, or the listener fails
     */
    protected void doAverageResult(Object[] template) throws Exception {
        Object[] newKey = removeKeyField(template);
        if (!m_ResultListener.isResultRequired(this, newKey)) {
            return;
        }
        Object[] resultTypes = m_ResultProducer.getResultTypes();
        Stats[] stats = new Stats[resultTypes.length];
        for (int j = 0; j < stats.length; j++) {
            stats[j] = new Stats();
        }
        // Correctly sized output array; its type placeholders are overwritten below.
        Object[] result = getResultTypes();
        Object[] firstMatch = null;
        int numMatches = 0;
        for (int i = 0; i < m_Keys.size(); i++) {
            if (!matchesTemplate(template, (Object[]) m_Keys.elementAt(i))) {
                continue;
            }
            Object[] currentResult = (Object[]) m_Results.elementAt(i);
            if (firstMatch == null) {
                firstMatch = currentResult;
            }
            numMatches++;
            for (int j = 0; j < resultTypes.length; j++) {
                if (!(resultTypes[j] instanceof Double) || stats[j] == null) {
                    continue;
                }
                if (currentResult[j] == null) {
                    // One missing value invalidates the average for this field.
                    stats[j] = null;
                } else {
                    double value = (Double) currentResult[j];
                    stats[j].add(value);
                }
            }
        }
        checkMatchCount(template, numMatches);
        result[0] = Double.valueOf(numMatches);
        int k = 1;
        for (int j = 0; j < resultTypes.length; j++) {
            if (resultTypes[j] instanceof Double) {
                Stats fieldStats = stats[j];
                if (fieldStats != null) {
                    fieldStats.calculateDerived();
                }
                result[k++] = fieldStats != null ? Double.valueOf(fieldStats.mean) : null;
                if (getCalculateStdDevs()) {
                    result[k++] = fieldStats != null ? Double.valueOf(fieldStats.stdDev) : null;
                }
            } else {
                result[k++] = firstMatch == null ? null : firstMatch[j];
            }
        }
        m_ResultListener.acceptResult(this, newKey, result);
    }

    /**
     * Verifies that the collected keys matching the template have distinct
     * values for the averaged-over field, and that there are exactly the
     * expected number of them.
     *
     * @param template the averaging template
     * @throws Exception on a duplicate key value or an unexpected match count
     */
    protected void checkForDuplicateKeys(Object[] template) throws Exception {
        HashSet<Object> seenKeyValues = new HashSet<Object>();
        int numMatches = 0;
        for (int i = 0; i < m_Keys.size(); i++) {
            Object[] current = (Object[]) m_Keys.elementAt(i);
            if (!matchesTemplate(template, current)) {
                continue;
            }
            if (!seenKeyValues.add(current[m_KeyIndex])) {
                throw new Exception("Duplicate result received:" + DatabaseUtils.arrayToString(current));
            }
            numMatches++;
        }
        checkMatchCount(template, numMatches);
    }

    /**
     * Verifies that all collected keys agree on every field except the
     * averaged-over one; only a single averaging group per run is supported.
     *
     * @throws Exception if no keys were collected, or any key differs from the
     *         first on a field other than the averaged-over one
     */
    protected void checkForMultipleDifferences() throws Exception {
        if (m_Keys.size() == 0) {
            throw new Exception("No results received from " + describeResultProducer());
        }
        Object[] firstKey = (Object[]) m_Keys.elementAt(0);
        // Check every key, not just the last: a differing key in the middle of
        // the run would otherwise go unnoticed.
        for (int k = 1; k < m_Keys.size(); k++) {
            Object[] otherKey = (Object[]) m_Keys.elementAt(k);
            boolean differs = otherKey.length != firstKey.length;
            for (int i = 0; !differs && i < firstKey.length; i++) {
                differs = i != m_KeyIndex && !fieldsEqual(firstKey[i], otherKey[i]);
            }
            if (differs) {
                throw new Exception("Keys differ on fields other than \"" + m_KeyFieldName + "\" -- time to implement multiple averaging");
            }
        }
    }

    /**
     * Called by the wrapped producer before it starts; forwards to the listener.
     *
     * @param rp the wrapped producer
     * @throws Exception if no listener is set or the listener fails
     */
    @Override
    public void preProcess(ResultProducer rp) throws Exception {
        if (m_ResultListener == null) {
            throw new Exception("No ResultListener set");
        }
        m_ResultListener.preProcess(this);
    }

    /**
     * Prepares for a series of runs: registers as the wrapped producer's
     * listener, resolves the key field index and pre-processes the producer.
     *
     * @throws Exception if no producer is set, the key field is not among its
     *         keys, or the producer's pre-processing fails
     */
    @Override
    public void preProcess() throws Exception {
        if (m_ResultProducer == null) {
            throw new Exception("No ResultProducer set");
        }
        m_ResultProducer.setResultListener(this);
        findKeyIndex();
        requireKeyIndex();
        m_ResultProducer.preProcess();
    }

    /**
     * Called by the wrapped producer when it finishes; forwards to the listener.
     *
     * @param rp the wrapped producer
     * @throws Exception if the listener fails
     */
    @Override
    public void postProcess(ResultProducer rp) throws Exception {
        m_ResultListener.postProcess(this);
    }

    /**
     * Finishes a series of runs by post-processing the wrapped producer.
     *
     * @throws Exception if the producer fails
     */
    @Override
    public void postProcess() throws Exception {
        m_ResultProducer.postProcess();
    }

    /**
     * Collects one key/result pair from the wrapped producer for later averaging.
     *
     * @param rp the producer sending the result; must be the wrapped producer
     * @param key the result key
     * @param result the result values
     * @throws Exception declared by the interface; not thrown here
     */
    @Override
    public void acceptResult(ResultProducer rp, Object[] key, Object[] result) throws Exception {
        if (m_ResultProducer != rp) {
            throw new Error("Unrecognized ResultProducer sending results!!");
        }
        m_Keys.addElement(key);
        m_Results.addElement(result);
    }

    /**
     * Always requests every result, since all of them are needed to average.
     *
     * @param rp the producer asking; must be the wrapped producer
     * @param key the candidate key (unused)
     * @return always true
     * @throws Exception declared by the interface; not thrown here
     */
    @Override
    public boolean isResultRequired(ResultProducer rp, Object[] key) throws Exception {
        if (m_ResultProducer != rp) {
            throw new Error("Unrecognized ResultProducer sending results!!");
        }
        return true;
    }

    /**
     * Returns the wrapped producer's key names without the averaged-over field.
     *
     * @return the reduced key names
     * @throws Exception if the key field has not been located
     */
    @Override
    public String[] getKeyNames() throws Exception {
        requireKeyIndex();
        return removeKeyField(m_ResultProducer.getKeyNames());
    }

    /**
     * Returns the wrapped producer's key types without the averaged-over field.
     *
     * @return the reduced key types
     * @throws Exception if the key field has not been located
     */
    @Override
    public Object[] getKeyTypes() throws Exception {
        requireKeyIndex();
        return removeKeyField(m_ResultProducer.getKeyTypes());
    }

    /**
     * Returns the result column names. The first is the count field, followed
     * by the producer's result names. When standard deviations are enabled,
     * each name is prefixed with "Avg_" and every numeric field is followed by
     * a "Dev_" column.
     *
     * @return the result names
     * @throws Exception if the wrapped producer fails
     */
    @Override
    public String[] getResultNames() throws Exception {
        String[] resultNames = m_ResultProducer.getResultNames();
        if (!getCalculateStdDevs()) {
            String[] plainNames = new String[resultNames.length + 1];
            plainNames[0] = m_CountFieldName;
            System.arraycopy(resultNames, 0, plainNames, 1, resultNames.length);
            return plainNames;
        }
        Object[] resultTypes = m_ResultProducer.getResultTypes();
        String[] newResultNames = new String[resultNames.length + 1 + countNumeric(resultTypes)];
        newResultNames[0] = m_CountFieldName;
        int j = 1;
        for (int i = 0; i < resultNames.length; i++) {
            newResultNames[j++] = "Avg_" + resultNames[i];
            if (resultTypes[i] instanceof Double) {
                newResultNames[j++] = "Dev_" + resultNames[i];
            }
        }
        return newResultNames;
    }

    /**
     * Returns the result column types, laid out in parallel with
     * {@link #getResultNames()}. The count and deviation columns are numeric.
     *
     * @return the result types
     * @throws Exception if the wrapped producer fails
     */
    @Override
    public Object[] getResultTypes() throws Exception {
        Object[] resultTypes = m_ResultProducer.getResultTypes();
        if (!getCalculateStdDevs()) {
            Object[] plainTypes = new Object[resultTypes.length + 1];
            plainTypes[0] = Double.valueOf(0.0);
            System.arraycopy(resultTypes, 0, plainTypes, 1, resultTypes.length);
            return plainTypes;
        }
        Object[] newResultTypes = new Object[resultTypes.length + 1 + countNumeric(resultTypes)];
        newResultTypes[0] = Double.valueOf(0.0);
        int j = 1;
        for (int i = 0; i < resultTypes.length; i++) {
            newResultTypes[j++] = resultTypes[i];
            if (resultTypes[i] instanceof Double) {
                newResultTypes[j++] = Double.valueOf(0.0);
            }
        }
        return newResultTypes;
    }

    /**
     * Describes the settings that affect result compatibility.
     *
     * @return the compatibility state string
     */
    @Override
    public String getCompatibilityState() {
        StringBuilder state = new StringBuilder(" -X ").append(getExpectedResultsPerAverage()).append(' ');
        if (getCalculateStdDevs()) {
            state.append("-S ");
        }
        if (m_ResultProducer == null) {
            state.append("<null ResultProducer>");
        } else {
            state.append("-W ").append(m_ResultProducer.getClass().getName());
            state.append(" -- ").append(m_ResultProducer.getCompatibilityState());
        }
        return state.toString().trim();
    }

    /**
     * Lists this producer's options, followed by those of the wrapped producer
     * if it is an {@link OptionHandler}.
     *
     * @return an enumeration of {@link Option}s
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tThe name of the field to average over.\n\t(default \"Fold\")", "F", 1, "-F <field name>"));
        newVector.addElement(new Option("\tThe number of results expected per average.\n\t(default 10)", "X", 1, "-X <num results>"));
        newVector.addElement(new Option("\tCalculate standard deviations.\n\t(default only averages)", "S", 0, "-S"));
        newVector.addElement(new Option("\tThe full class name of a ResultProducer.\n\teg: weka.experiment.CrossValidationResultProducer", "W", 1, "-W <class name>"));
        if (m_ResultProducer instanceof OptionHandler) {
            newVector.addElement(new Option("", "", 0, "\nOptions specific to result producer " + m_ResultProducer.getClass().getName() + ":"));
            Enumeration enu = ((OptionHandler) m_ResultProducer).listOptions();
            while (enu.hasMoreElements()) {
                newVector.addElement((Option) enu.nextElement());
            }
        }
        return newVector.elements();
    }

    /**
     * Parses the options -F, -X, -S and the mandatory -W. Options after "--"
     * are passed to the wrapped producer.
     *
     * @param options the option array
     * @throws Exception if -W is missing, -X is not an integer, or the producer
     *         cannot be created or configured
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String keyFieldName = Utils.getOption('F', options);
        setKeyFieldName(keyFieldName.length() != 0 ? keyFieldName : CrossValidationResultProducer.FOLD_FIELD_NAME);
        String numResults = Utils.getOption('X', options);
        setExpectedResultsPerAverage(numResults.length() != 0 ? Integer.parseInt(numResults) : 10);
        setCalculateStdDevs(Utils.getFlag('S', options));
        String rpName = Utils.getOption('W', options);
        if (rpName.length() == 0) {
            throw new Exception("A ResultProducer must be specified with the -W option.");
        }
        setResultProducer((ResultProducer) Utils.forName(ResultProducer.class, rpName, null));
        if (getResultProducer() instanceof OptionHandler) {
            ((OptionHandler) getResultProducer()).setOptions(Utils.partitionOptions(options));
        }
    }

    /**
     * Returns the current settings as an option array. Unused trailing slots
     * are filled with empty strings.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        String[] producerOptions = new String[0];
        if (m_ResultProducer instanceof OptionHandler) {
            producerOptions = ((OptionHandler) m_ResultProducer).getOptions();
        }
        // 8 = -F name -X num -S -W class --
        String[] options = new String[producerOptions.length + 8];
        int current = 0;
        options[current++] = "-F";
        options[current++] = getKeyFieldName();
        options[current++] = "-X";
        options[current++] = String.valueOf(getExpectedResultsPerAverage());
        if (getCalculateStdDevs()) {
            options[current++] = "-S";
        }
        if (getResultProducer() != null) {
            options[current++] = "-W";
            options[current++] = getResultProducer().getClass().getName();
        }
        options[current++] = "--";
        System.arraycopy(producerOptions, 0, options, current, producerOptions.length);
        current += producerOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /**
     * Stores the additional measures and forwards them to the wrapped producer.
     *
     * @param additionalMeasures the measure names (may be null)
     */
    @Override
    public void setAdditionalMeasures(String[] additionalMeasures) {
        m_AdditionalMeasures = additionalMeasures;
        if (m_ResultProducer != null) {
            System.err.println("AveragingResultProducer: setting additional measures for ResultProducer");
            m_ResultProducer.setAdditionalMeasures(m_AdditionalMeasures);
        }
    }

    /**
     * Enumerates the additional measures of the wrapped producer, if it
     * provides any.
     *
     * @return an enumeration of measure names (empty if none)
     */
    @Override
    public Enumeration enumerateMeasures() {
        Vector<String> newVector = new Vector<String>();
        if (m_ResultProducer instanceof AdditionalMeasureProducer) {
            Enumeration en = ((AdditionalMeasureProducer) m_ResultProducer).enumerateMeasures();
            while (en.hasMoreElements()) {
                newVector.addElement((String) en.nextElement());
            }
        }
        return newVector.elements();
    }

    /**
     * Returns an additional measure from the wrapped producer.
     *
     * @param additionalMeasureName the measure name
     * @return the measure value
     * @throws IllegalArgumentException if the wrapped producer is not an
     *         {@link AdditionalMeasureProducer}
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (m_ResultProducer instanceof AdditionalMeasureProducer) {
            return ((AdditionalMeasureProducer) m_ResultProducer).getMeasure(additionalMeasureName);
        }
        throw new IllegalArgumentException("AveragingResultProducer: Can't return value for : " + additionalMeasureName + ". " + describeResultProducer() + " " + "is not an AdditionalMeasureProducer");
    }

    /**
     * Sets the dataset passed to the wrapped producer before each run.
     *
     * @param instances the dataset
     */
    @Override
    public void setInstances(Instances instances) {
        m_Instances = instances;
    }

    /** @return tip text for the calculateStdDevs property */
    public String calculateStdDevsTipText() {
        return "Record standard deviations for each run.";
    }

    /** @return whether standard deviations are computed */
    public boolean getCalculateStdDevs() {
        return m_CalculateStdDevs;
    }

    /** @param newCalculateStdDevs whether to compute standard deviations */
    public void setCalculateStdDevs(boolean newCalculateStdDevs) {
        m_CalculateStdDevs = newCalculateStdDevs;
    }

    /** @return tip text for the expectedResultsPerAverage property */
    public String expectedResultsPerAverageTipText() {
        return "Set the expected number of results to average per run. For example if a CrossValidationResultProducer is being used (with the number of folds set to 10), then the expected number of results per run is 10.";
    }

    /** @return the number of results expected per average */
    public int getExpectedResultsPerAverage() {
        return m_ExpectedResultsPerAverage;
    }

    /** @param newExpectedResultsPerAverage the number of results expected per average */
    public void setExpectedResultsPerAverage(int newExpectedResultsPerAverage) {
        m_ExpectedResultsPerAverage = newExpectedResultsPerAverage;
    }

    /** @return tip text for the keyFieldName property */
    public String keyFieldNameTipText() {
        return "Set the field name that will be unique for a run.";
    }

    /** @return the name of the key field averaged over */
    public String getKeyFieldName() {
        return m_KeyFieldName;
    }

    /**
     * Sets the key field to average over, renames the count column to
     * "Num_" + name, and re-resolves the key index.
     *
     * @param newKeyFieldName the key field name
     */
    public void setKeyFieldName(String newKeyFieldName) {
        m_KeyFieldName = newKeyFieldName;
        m_CountFieldName = "Num_" + m_KeyFieldName;
        findKeyIndex();
    }

    /**
     * Sets the listener that receives the averaged results.
     *
     * @param listener the listener
     */
    @Override
    public void setResultListener(ResultListener listener) {
        m_ResultListener = listener;
    }

    /** @return tip text for the resultProducer property */
    public String resultProducerTipText() {
        return "Set the resultProducer for which results are to be averaged.";
    }

    /** @return the wrapped result producer */
    public ResultProducer getResultProducer() {
        return m_ResultProducer;
    }

    /**
     * Sets the producer whose results are averaged, registers this object as
     * its listener and re-resolves the key index.
     *
     * @param newResultProducer the producer (may be null, leaving no producer set)
     */
    public void setResultProducer(ResultProducer newResultProducer) {
        m_ResultProducer = newResultProducer;
        if (m_ResultProducer != null) {
            m_ResultProducer.setResultListener(this);
        }
        findKeyIndex();
    }

    /**
     * Returns a description of the settings and the current dataset.
     *
     * @return the description
     */
    @Override
    public String toString() {
        String result = "AveragingResultProducer: " + getCompatibilityState();
        if (m_Instances == null) {
            return result + ": <null Instances>";
        }
        return result + ": " + Utils.backQuoteChars(m_Instances.relationName());
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8034 $");
    }

    /** Builds the averaging template from the keys collected so far and validates it. */
    private Object[] templateFromCollectedKeys() throws Exception {
        checkForMultipleDifferences();
        Object[] template = ((Object[]) m_Keys.elementAt(0)).clone();
        template[m_KeyIndex] = null;
        checkForDuplicateKeys(template);
        return template;
    }

    /** Copies {@code fields} into a String array, omitting the averaged-over key field. */
    private String[] removeKeyField(Object[] fields) {
        String[] reduced = new String[fields.length - 1];
        System.arraycopy(fields, 0, reduced, 0, m_KeyIndex);
        System.arraycopy(fields, m_KeyIndex + 1, reduced, m_KeyIndex, fields.length - m_KeyIndex - 1);
        return reduced;
    }

    /** Throws unless the number of results matching the template is the expected count. */
    private void checkMatchCount(Object[] template, int numMatches) throws Exception {
        if (numMatches != m_ExpectedResultsPerAverage) {
            throw new Exception("Expected " + m_ExpectedResultsPerAverage + " results matching key \"" + DatabaseUtils.arrayToString(template) + "\" but got " + numMatches);
        }
    }

    /** Throws unless the averaged-over key field has been located in the producer's keys. */
    private void requireKeyIndex() throws Exception {
        if (m_KeyIndex == -1) {
            throw new Exception("No key field called " + m_KeyFieldName + " produced by " + describeResultProducer());
        }
    }

    /** Names the wrapped producer's class for messages, tolerating a null producer. */
    private String describeResultProducer() {
        return m_ResultProducer == null ? "<null ResultProducer>" : m_ResultProducer.getClass().getName();
    }

    /** Null-safe equality of two key field values. */
    private static boolean fieldsEqual(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Counts the numeric (Double-typed) entries in a result-type array. */
    private static int countNumeric(Object[] resultTypes) {
        int numNumeric = 0;
        for (Object type : resultTypes) {
            if (type instanceof Double) {
                numNumeric++;
            }
        }
        return numNumeric;
    }
}

